# -*- coding:utf-8 -*-

from tensorflow.examples.tutorials.mnist import input_data
from dto.devide_dataset_dto import DividedDataSetDto

"""
Load the MNIST dataset.
"""


def read(ds_dir):
    """
    Load the raw MNIST dataset and bundle its three splits into a single DTO.

    :param ds_dir: dataset location, forwarded as the first argument to
        ``input_data.read_data_sets``
    :return: a ``DividedDataSetDto`` built from the images, labels and example
        count of the train, validation and test splits, in that order
    """
    # Labels are requested with one_hot=True.
    mnist = input_data.read_data_sets(ds_dir, one_hot=True)

    # Collect (images, labels, num_examples) for each split; the DTO takes
    # these nine values positionally in train / validation / test order.
    dto_args = []
    for split in (mnist.train, mnist.validation, mnist.test):
        dto_args.extend((split.images, split.labels, split.num_examples))

    return DividedDataSetDto(*dto_args)
